import csv
import json
import os
import numpy as np
from tqdm import tqdm
# from evaluate.evaluate_distrib_rl import generate_game_plot
from generic.data_util import divide_dataset_according2date, judge_home_away, read_feature_mean_scale, \
    reverse_standard_data, ICEHOCKEY_ACTIONS, print_game_events_info
from generic.model_util import to_np
from generic.plot_util import plot_curve


class PlayerRanking:

    def __init__(self, agent,
                 # uncertainty_thresholds=(0.0001, 0.001, 0.01, 0.1, 1),
                 rank_metric,
                 alphas=None,
                 debug_mode=False,
                 mode='test',
                 sanity_check_msg='',
                 uncertainty_model='gda',
                 use_expectation_base=False,
                 log_file=None):
        """Load reference player statistics and collect per-player values for ``rank_metric``.

        Args:
            agent: Project agent object. This constructor reads ``train_data_path``,
                ``sports``, ``train_rate`` and ``apply_data_date_div`` from it; the
                per-game routines it triggers call further agent helpers.
            rank_metric: Metric name, matched by substring against 'UGIM', 'GIM',
                'EG', 'SI' and 'VAEP' (in that order).
            alphas: Iterable of risk levels used by the 'UGIM' metric; each entry is
                either the string 'mean' or a number. Defaults to ``['mean']``.
            debug_mode: When true, only the last two game files are processed.
            mode: One of 'train', 'valid', 'test' or 'all'; selects the game files
                and the date window when the dataset split returns split dates.
            sanity_check_msg: Forwarded to the agent's per-game value routines.
            uncertainty_model: Forwarded to the agent's uncertainty routine.
            use_expectation_base: When true, the 'UGIM' impact uses the mean value
                of the previous event as its baseline.
            log_file: File object passed to ``print`` for log output
                (``None`` prints to stdout).

        Raises:
            ValueError: If ``mode`` or ``agent.sports`` is not recognised.
        """
        self.agent = agent
        # A fresh list per instance: a list literal as the default would be shared
        # between every instance created without an explicit ``alphas``.
        self.alphas = ['mean'] if alphas is None else alphas
        self.log_file = log_file
        self.sanity_check_msg = sanity_check_msg
        self.rank_metric = rank_metric
        self.uncertainty_model = uncertainty_model
        self.use_expectation_base = use_expectation_base

        # The original code compared against 'ice_hockey' in one place and
        # 'ice-hockey' in others, which made the EG/SI collection unreachable.
        # Both spellings are accepted everywhere in this constructor.
        is_ice_hockey = self.agent.sports in ('ice_hockey', 'ice-hockey')

        if 'UGIM' in self.rank_metric:
            self.player_impact_dict_by_alpha = {}
            for alpha in self.alphas:
                self.player_impact_dict_by_alpha.update({alpha: {}})
            # if self.agent.all_gda_models is not None:
            self.player_uncertainty_dict_by_alpha = {}  # record the uncertainty of each impact
            for alpha in self.alphas:
                self.player_uncertainty_dict_by_alpha.update({alpha: {}})
        elif 'GIM' in self.rank_metric:
            self.player_impact_dict = {}
        elif 'EG' in self.rank_metric:
            self.player_expected_goal_dict = {}
        elif 'SI' in self.rank_metric:
            self.player_score_impact_dict = {}
        elif 'VAEP' in self.rank_metric:
            self.player_vaep_dict = {}

        all_files = sorted(os.listdir(self.agent.train_data_path))
        training_files, validation_files, testing_files, split_dates = \
            divide_dataset_according2date(all_data_files=all_files,
                                          sports=self.agent.sports,
                                          train_rate=self.agent.train_rate,
                                          # if_split=False if self.agent.sports == 'soccer' else self.agent.apply_data_date_div,
                                          if_split=self.agent.apply_data_date_div,
                                          if_return_split=True)
        if split_dates is not None:
            if mode == 'train':
                self.game_files = training_files
                self.start_date = split_dates[0]
                self.end_date = split_dates[1]
            elif mode == 'valid':
                self.game_files = validation_files
                self.start_date = split_dates[1]
                self.end_date = split_dates[2]
            elif mode == 'test':
                self.game_files = testing_files
                self.start_date = split_dates[2]
                self.end_date = split_dates[3]
            elif mode == 'all':
                self.game_files = all_files
                self.start_date = split_dates[0]
                self.end_date = split_dates[3]
            else:
                raise ValueError("Unknown running mode {0}".format(mode))
        else:
            # NOTE(review): start_date/end_date are not set on this path, but the
            # ice-hockey statistics routine reads them - confirm that date splitting
            # is always enabled for ice hockey.
            self.game_files = all_files

        if debug_mode:
            self.game_files = self.game_files[-2:]

        self.player_stats_dict = {}

        # self.interested_features = ['G', 'GWG', 'OTG', 'SHG', 'PPG', 'P', 'SHP', 'PPP', 'A', 'PIM', 'GP', 'S']

        if is_ice_hockey:
            self.interested_measures = ['A', 'G', 'GWG', 'OTG', 'SHG', 'PPG', 'P', 'SHP', 'PPP', 'PIM', 'GP', 'S']
            home_away_ids_dir = '../icehockey-data/game_id_home_away_2018_19.json'
            with open(home_away_ids_dir, 'r') as f:
                self.home_away_game_ids = json.load(f)
            # Despite the .csv extension this file is parsed as JSON.
            player_summary_dir = '../icehockey-data/NHL_players_game_summary_201819.csv'
            with open(player_summary_dir) as f:
                self.players_gbg_stats = json.load(f)
            game_date_dir = '../icehockey-data/game_dates_2018_2019.json'
            with open(game_date_dir) as f:
                self.all_game_dates = json.load(f)
            self.compute_ice_hockey_players_stats_by_time()
        elif self.agent.sports == 'soccer':
            # TODO: maybe we should set '-' to 0 instead of removing them
            with open('../soccer-data/player_profile_with_alg_values.csv') as player_profile_file:
                player_profile_reader = csv.DictReader(player_profile_file)
                self.player_profile_all = []
                for player_profile in player_profile_reader:
                    p_name = player_profile['playerName']
                    t_name = player_profile['teamName']
                    player_id = player_profile['playerId']
                    self.player_profile_all.append([p_name, t_name, player_id])

            self.interested_measures = {'summary': ['Mins', 'Goals', 'Assists', 'Yel', 'Red',
                                                    'SpG', 'PS', 'MotM'],
                                        'defensive': ['Mins', 'Tackles', 'Inter', 'Fouls', 'Offsides',
                                                      'Clear', 'Drb', 'Blocks', 'OwnG'],
                                        'offensive': ['Mins', 'Goals', 'SpG', 'KeyP', 'Drb', 'Fouled',
                                                      'Off', 'Disp', 'UnsTch'],
                                        'passing': ['Mins', 'Assists', 'KeyP', 'AvgP', 'PS', 'Crosses', 'LongB',
                                                    'ThrB']
                                        }
            self.player_summary_dir_dict = \
                {'summary': ['../soccer-data/whoScored/Championship/Championship_summary.csv',
                             '../soccer-data/whoScored/PremierLeague/Premier_League_summary.csv'],
                 'defensive': ['../soccer-data/whoScored/Championship/Championship_defensive.csv',
                               '../soccer-data/whoScored/PremierLeague/Premier_League_defensive.csv'],
                 'offensive': ['../soccer-data/whoScored/Championship/Championship_offensive.csv',
                               '../soccer-data/whoScored/PremierLeague/Premier_League_offensive.csv'],
                 'passing': ['../soccer-data/whoScored/Championship/Championship_passing.csv',
                             '../soccer-data/whoScored/PremierLeague/Premier_League_passing.csv'],
                 }
            game_date_dir = '../soccer-data/game_dates_2017_2018.json'
            with open(game_date_dir) as f:
                self.all_game_dates = json.load(f)
            home_away_ids_dir = '../soccer-data/game_id_home_away_2017_18.json'
            with open(home_away_ids_dir, 'r') as f:
                self.home_away_game_ids = json.load(f)
            self.compute_soccer_players_stats()
        else:
            raise ValueError("Unknown sports: {0}".format(self.agent.sports))

        self.max_uncer = -float('inf')
        self.min_uncer = float('inf')
        # from tqdm import tqdm
        if 'UGIM' in self.rank_metric:
            output_uncertainties_all = []
            for game_name in tqdm(self.game_files):
                # for idx in tqdm(range(len(self.game_files)), desc="Collecting player impacts."):
                output_uncertainties = self.compute_risk_impact_by_game(game_name)
                output_uncertainties_all.append(output_uncertainties)
            output_uncertainties_all = np.concatenate(output_uncertainties_all)
            output_uncertainties_all.sort()
            self.max_uncer = np.max(output_uncertainties_all)
            self.min_uncer = np.min(output_uncertainties_all)
            self.output_uncertainties_all = output_uncertainties_all
        elif 'GIM' in self.rank_metric:
            for game_name in tqdm(self.game_files):
                self.compute_impact_by_game(game_name)
        elif 'EG' in self.rank_metric:
            # for idx in range(len(self.game_files)):
            #     game_name = self.game_files[idx]
            #     self.compute_expected_goal_by_game(game_name)
            if is_ice_hockey:
                for game_name in tqdm(self.game_files):
                    self.compute_score_impact_by_game(game_name, if_expected=True)
        elif 'SI' in self.rank_metric:
            # value_dir = '../../sport-analytic-markov/player_impact/' \
            #             'ice_hockey_player_markov_impact-2021-December-08-.json'
            # with open(value_dir, 'rb') as file:
            #     player_score_impact_dict = json.load(file)
            # for key in player_score_impact_dict:
            #     self.player_score_impact_dict.update({key: player_score_impact_dict[key]})
            if is_ice_hockey:
                for game_name in tqdm(self.game_files):
                    self.compute_score_impact_by_game(game_name)
        elif 'VAEP' in self.rank_metric:
            pass

        print("max uncertainty: {0} and min uncertainty: {1}.".format(
            self.max_uncer, self.min_uncer), file=log_file, flush=True)

    def compute_ice_hockey_players_stats_by_time(self):
        """Fill ``self.player_stats_dict`` with each player's statistics inside the date window.

        For every player in ``self.players_gbg_stats`` the game-by-game records are
        scanned in list order. The last record dated before ``self.start_date`` is the
        baseline and the last record dated within ``[start_date, end_date]`` is the
        end value; their difference over ``self.interested_measures`` is stored under
        the player's id. A missing baseline counts as zeros, and a missing end value
        falls back to the baseline (giving an all-zero difference).

        NOTE(review): this presumably relies on the records being cumulative totals
        in chronological order - confirm against the data file.
        """
        from datetime import datetime as dt

        def snapshot(record):
            # Values of the measures of interest for one game record.
            return np.asarray([record[measure] for measure in self.interested_measures])

        for entry in self.players_gbg_stats.values():
            records = entry['gbg_summary_list']
            pid = entry['id']
            baseline = None
            window_end = None
            for record in records:
                played_on = dt.strptime(str(record['game_date']), "%Y%m%d")
                if played_on < self.start_date:
                    baseline = snapshot(record)
                if self.start_date <= played_on <= self.end_date:
                    window_end = snapshot(record)

            if baseline is None:
                baseline = np.asarray([0 for _ in self.interested_measures])
            if window_end is None:
                window_end = baseline

            self.player_stats_dict.update({pid: window_end - baseline})

    def get_soccer_id(self, playername, teamname):
        """Find a player id in ``self.player_profile_all`` by substring match.

        Each profile is a ``[player name, team name, player id]`` entry. The first
        profile whose player name contains ``playername`` and whose team name
        contains ``teamname`` wins.

        Returns:
            ``(True, player_id)`` for the first match, otherwise ``(False, '')``.
        """
        for profile in self.player_profile_all:
            known_player, known_team = profile[0], profile[1]
            if not (playername in known_player and teamname in known_team):
                continue
            return True, profile[2]
        return False, ''

    def compute_soccer_players_stats(self):
        """Build ``self.player_stats_dict_by_metric`` from the per-category CSV files.

        For every category in ``self.player_summary_dir_dict`` each listed CSV file is
        read, the row's player is resolved to an id via ``get_soccer_id`` and, when
        found, the category's metrics from ``self.interested_measures`` are stored as
        floats under ``metric -> player id``. A '-' cell is stored as 0.0.

        The per-metric dict is re-created at the start of each category, so a metric
        listed in several categories ends up holding only the values read for the
        last category that lists it.
        """
        self.player_stats_dict_by_metric = {}
        for category, stats_paths in self.player_summary_dir_dict.items():
            wanted_metrics = self.interested_measures[category]
            for metric in wanted_metrics:
                self.player_stats_dict_by_metric[metric] = {}
            for stats_path in stats_paths:
                with open(stats_path) as stats_file:
                    for row in csv.DictReader(stats_file):
                        player = row['name']
                        team = row['team']
                        # Some team cells carry literal surrounding quotes and extra
                        # comma-separated details; keep only the leading team name.
                        if team[0] == '"':
                            team = team[1:-1]
                        team = team.split(',')[0]
                        found, player_id = self.get_soccer_id(player, team)
                        if not found:
                            continue
                        for metric in wanted_metrics:
                            raw = row[metric]
                            value = float(0 if raw == '-' else raw)
                            self.player_stats_dict_by_metric[metric][player_id] = value

    def compute_correlations_per_alpha(self, uncertainty_threshold=float('inf'),
                                       apply_uncertainty=True, reverse=False, verbose=False):
        """Correlate uncertainty-filtered player impacts with reference statistics.

        For every alpha in ``self.alphas`` each player's recorded impacts are summed,
        optionally keeping only those whose recorded uncertainty passes the
        threshold, and the per-player sums are correlated with the sport-specific
        reference statistics.

        Args:
            uncertainty_threshold: Cut-off applied to the per-impact uncertainty.
            apply_uncertainty: When false, every impact is kept.
            reverse: When false, impacts with ``uncertainty <= threshold`` are kept;
                when true, impacts with ``uncertainty >= threshold`` are kept.
            verbose: Accepted for interface compatibility only. The correlation
                helpers are always called with ``verbose=True``, so the per-measure
                lines are always written to ``self.log_file``.

        Returns:
            ``(corrcoef_all, corrcoef_strings)``: one entry per alpha, as returned
            by the sport-specific correlation helper.

        Raises:
            ValueError: If ``self.agent.sports`` is not a known sport.
        """
        corrcoef_all = []
        corrcoef_strings = []
        for alpha in self.alphas:
            player_rgim_dict = {}
            for pid, impacts in self.player_impact_dict_by_alpha[alpha].items():
                if apply_uncertainty:
                    uncertainties = self.player_uncertainty_dict_by_alpha[alpha][pid]
                    if reverse:
                        valid_player_impacts = [impact for impact, uncertainty in zip(impacts, uncertainties)
                                                if uncertainty >= uncertainty_threshold]
                    else:
                        valid_player_impacts = [impact for impact, uncertainty in zip(impacts, uncertainties)
                                                if uncertainty <= uncertainty_threshold]
                else:
                    valid_player_impacts = list(impacts)
                player_rgim_dict[pid] = np.sum(valid_player_impacts)

            # The constructor selects ice hockey with the 'ice_hockey' spelling, while
            # this method only knew 'ice-hockey' and therefore always raised for
            # hockey agents. Both spellings are accepted.
            if self.agent.sports in ('ice_hockey', 'ice-hockey'):
                corrcoef_values, corrcoef_string = self.compute_hockey_correlations(player_rgim_dict, True)
            elif self.agent.sports == 'soccer':
                corrcoef_values, corrcoef_string = self.compute_soccer_correlations(player_rgim_dict, True)
            else:
                raise ValueError("Unknown sports: {0}".format(self.agent.sports))
            corrcoef_all.append(corrcoef_values)
            corrcoef_strings.append(corrcoef_string)
        return corrcoef_all, corrcoef_strings

    def compute_soccer_correlations(self, player_eval_metric_dict, verbose):
        """Correlate a per-player evaluation value with every soccer reference metric.

        Args:
            player_eval_metric_dict: Mapping ``player id -> evaluation value``.
            verbose: When true, each ``metric:correlation`` line (and a trailing
                blank line) is written to ``self.log_file``.

        Returns:
            ``(values, text)``: the list of correlation coefficients in the order of
            ``self.player_stats_dict_by_metric`` and the same coefficients joined as
            ``'& <coefficient> '`` fragments.
        """
        corrcoef_value = []
        fragments = []
        for metric, stats_by_pid in self.player_stats_dict_by_metric.items():
            # Only players present in both sources take part in the correlation.
            shared_pids = [pid for pid in stats_by_pid if pid in player_eval_metric_dict]
            reference_values = [stats_by_pid[pid] for pid in shared_pids]
            eval_values = [player_eval_metric_dict[pid] for pid in shared_pids]
            coefficient = np.corrcoef(np.asarray(eval_values), np.asarray(reference_values))[0][1]
            print(len(eval_values))
            corrcoef_value.append(coefficient)
            fragments.append('& {1} '.format(metric, coefficient))
            if verbose:
                print('{0}:{1}'.format(metric, coefficient), file=self.log_file, flush=True)
        if verbose:
            print('\n', file=self.log_file, flush=True)
        return corrcoef_value, ''.join(fragments)

    def compute_hockey_correlations(self, player_eval_metric_dict, verbose):
        """Correlate a per-player evaluation value with every ice-hockey reference measure.

        Players are taken from ``player_eval_metric_dict`` when they also appear in
        ``self.player_stats_dict``; column ``i`` of a player's stats vector belongs to
        ``self.interested_measures[i]``.

        Args:
            player_eval_metric_dict: Mapping ``player id -> evaluation value``.
            verbose: When true, each ``measure:correlation`` line (and a trailing
                blank line) is written to ``self.log_file``.

        Returns:
            ``(values, text)``: the correlation coefficients in measure order and the
            same coefficients joined as ``'& <coefficient> '`` fragments.
        """
        shared_pids = [pid for pid in player_eval_metric_dict if pid in self.player_stats_dict]
        eval_values = np.asarray([player_eval_metric_dict[pid] for pid in shared_pids])
        stats_matrix = np.asarray([self.player_stats_dict[pid] for pid in shared_pids])

        corrcoef_value = []
        fragments = []
        for column, measure in enumerate(self.interested_measures):
            coefficient = np.corrcoef(eval_values, stats_matrix[:, column])[0][1]
            corrcoef_value.append(coefficient)
            fragments.append('& {1} '.format(measure, coefficient))
            if verbose:
                print('{0}:{1}'.format(measure, coefficient), file=self.log_file, flush=True)
        if verbose:
            print('\n', file=self.log_file, flush=True)
        return corrcoef_value, ''.join(fragments)

    def load_soccer_player_eg(self):
        """Return the parsed content of the soccer Markov Q-value JSON file.

        The path is relative to the current working directory.
        """
        with open('../soccer-data/soccer_player_markov_q-2019July22shot.json') as eg_file:
            return json.load(eg_file)

    def load_soccer_player_pm(self):
        """Return the parsed content of the soccer plus-minus JSON file.

        The path is relative to the current working directory.
        """
        with open('../soccer-data/pm_player_all.json') as pm_file:
            return json.load(pm_file)


    def load_soccer_player_si(self):
        """Return the parsed content of the soccer Markov impact JSON file.

        This loader was previously a second definition of ``load_soccer_player_pm``;
        being defined later in the class body it replaced the plus-minus loader, so
        plus-minus callers received this file's data instead. It now has its own name.
        The path is relative to the current working directory.
        """
        soccer_player_si_dir = '../soccer-data/soccer_player_markov_impact-2019June13.json'
        with open(soccer_player_si_dir) as read_file:
            player_si_dict = json.load(read_file)
        return player_si_dict

    def load_ice_hockey_player_pm(self):
        """Return each ice-hockey player's '+/-' change inside the date window.

        For every player in ``self.players_gbg_stats`` the game-by-game records are
        scanned in list order. The last '+/-' value dated before ``self.start_date``
        is the baseline and the last one dated within ``[start_date, end_date]`` is
        the end value. A missing baseline counts as 0 and a missing end value falls
        back to the baseline.

        Returns:
            Mapping ``player id -> end value minus baseline``.
        """
        from datetime import datetime as dt

        plus_minus_by_player = {}
        for entry in self.players_gbg_stats.values():
            records = entry['gbg_summary_list']
            pid = entry['id']
            baseline = None
            window_end = None
            for record in records:
                played_on = dt.strptime(str(record['game_date']), "%Y%m%d")
                if played_on < self.start_date:
                    baseline = np.asarray([record['+/-']])
                if self.start_date <= played_on <= self.end_date:
                    window_end = np.asarray([record['+/-']])

            if baseline is None:
                baseline = np.asarray([0])
            if window_end is None:
                window_end = baseline

            plus_minus_by_player[pid] = (window_end - baseline)[0]
        return plus_minus_by_player

    def compute_score_impact_by_game(self, game_name, if_expected=False):
        """Accumulate pre-computed Markov values of one game per player.

        The per-event values are read from a JSON file under
        ``../../sport-analytic-markov/resources/ice-hockey-values/`` (Q values when
        ``if_expected`` is true, impact values otherwise). Each event's value for the
        acting team's side is appended to the acting player's list in
        ``self.player_expected_goal_dict`` (``if_expected``) or
        ``self.player_score_impact_dict``.

        When 'EG' is part of ``self.rank_metric`` only events whose detected action
        name contains 'carry' are kept.

        Side effect: sets ``cut_at_goal`` and ``keep_goal_state`` on the agent to
        ``True`` and does not restore them.

        Args:
            game_name: Game identifier used for the agent loaders and the value path.
            if_expected: Selects the Q-value file and the expected-goal dict.
        """
        tids = self.agent.load_team_id(game_name)
        pids = self.agent.load_player_id(game_label=game_name)
        if if_expected:
            value_dir = '../../sport-analytic-markov/resources/ice-hockey-values/' \
                        '{0}-playsequence-wpoi/markov_Qs_values_Iter-10.json'.format(game_name)
        else:
            value_dir = '../../sport-analytic-markov/resources/ice-hockey-values/' \
                        '{0}-playsequence-wpoi/markov_impact_values_Iter-10.json'.format(game_name)
        self.agent.cut_at_goal = True
        self.agent.keep_goal_state = True
        s_a_sequence, r_sequence = self.agent.load_sports_data(game_label=game_name, need_check=False)
        transition_game = self.agent.build_transitions(s_a_data=s_a_sequence,
                                                       r_data=r_sequence,
                                                       pid_sequence=pids)

        with open(value_dir, 'rb') as file:
            si_values = json.load(file)
        for event_idx in range(0, len(tids)):

            state_action_data = transition_game[event_idx].state_action
            state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                        data_means=self.agent.data_means,
                                                        data_stds=self.agent.data_stds)
            # The performed action is the candidate with the largest positive flag;
            # it stays None when no candidate flag is above zero.
            action = None
            max_action_label = 0
            for candidate_action in ICEHOCKEY_ACTIONS:
                if state_action_origin[candidate_action] > max_action_label:
                    max_action_label = state_action_origin[candidate_action]
                    action = candidate_action
            si_value = si_values[str(event_idx)]
            # NOTE(review): this looks like leftover debug output - confirm before removing.
            print(action)
            # if 'SI' in self.rank_metric:
            #     if 'goal' in action:
            #         continue
            if 'EG' in self.rank_metric:
                # An undetected action (None) cannot be a carry; skipping it here
                # avoids the TypeError that ``'carry' not in None`` would raise.
                if action is None or 'carry' not in action:
                    continue
                # if 'assist' in action:
                #     continue
                # if 'goal' in action:
                #     continue
                # if 'shot' in action:
                #     continue
            pid = pids[event_idx]
            home_away_str = judge_home_away(home_away_game_ids=self.home_away_game_ids,
                                            teamId=tids[event_idx],
                                            gameId=game_name)
            # ha_id = 0 if home_away_id == 'home' else 1
            if if_expected:
                self.player_expected_goal_dict.setdefault(pid, []).append(si_value[home_away_str])
            else:
                self.player_score_impact_dict.setdefault(pid, []).append(si_value[home_away_str])

    def compute_expected_goal_by_game(self, game_name):
        """Append every event's value for the acting team's side to the acting player.

        Values and transitions come from ``self.agent.compute_values_by_game``; the
        home side uses value index 0 and any other side index 1. Results accumulate
        in ``self.player_expected_goal_dict`` as ``player id -> list of values``.
        """
        tids = self.agent.load_team_id(game_name)
        output_values, transition_game = self.agent.compute_values_by_game(
            game_name, self.sanity_check_msg)
        pids = [transition.pid for transition in transition_game]
        for event_idx, pid in enumerate(pids):
            value = output_values[event_idx]
            side = judge_home_away(home_away_game_ids=self.home_away_game_ids,
                                   teamId=tids[event_idx],
                                   gameId=game_name)
            ha_id = 0 if side == 'home' else 1
            self.player_expected_goal_dict.setdefault(pid, []).append(value[ha_id])

    def compute_impact_by_game(self, game_name):
        """Accumulate per-event impacts of one game in ``self.player_impact_dict``.

        The impact of an event is ``gamma * value - previous value`` taken at the
        acting team's index (0 for home, 1 otherwise). The first event has no
        predecessor and gets no impact; after an event flagged ``done`` the next
        event is skipped as well because its baseline is cleared.
        """
        tids = self.agent.load_team_id(game_name)
        output_values, transition_game = self.agent.compute_values_by_game(game_name, self.sanity_check_msg)
        pids = [transition.pid for transition in transition_game]
        dones = [transition.done for transition in transition_game]
        previous_value = output_values[0, :]
        for event_idx in range(1, len(transition_game)):
            current_value = output_values[event_idx]
            if previous_value is not None:
                side = judge_home_away(home_away_game_ids=self.home_away_game_ids,
                                       teamId=tids[event_idx],
                                       gameId=game_name)
                ha_id = 0 if side == 'home' else 1
                impact = (self.agent.gamma * current_value - previous_value)[ha_id]
                self.player_impact_dict.setdefault(pids[event_idx], []).append(impact)

            # An ended episode leaves no baseline for the following event.
            previous_value = None if dones[event_idx] else current_value

    def compute_risk_impact_by_game(self, game_name):
        """Accumulate per-alpha impacts and their uncertainties for one game.

        For every alpha in ``self.alphas`` the value of an event is either the mean
        over the last axis of the agent's output (alpha == 'mean') or the entry at
        index ``int(num_tau * alpha)`` of that axis. The impact of an event is
        ``gamma * value - baseline`` at the acting team's index (0 for home, 1
        otherwise), where the baseline is the previous event's value - or the
        previous event's mean when ``self.use_expectation_base`` is set. No impact
        is recorded for the first event or for the event following one flagged
        ``done``.

        Impacts go to ``self.player_impact_dict_by_alpha[alpha][pid]`` and the
        matching event uncertainty to ``self.player_uncertainty_dict_by_alpha``.

        Returns:
            The per-event uncertainties returned by the agent for this game.
        """
        tids = self.agent.load_team_id(game_name)
        output_values, transition_game = self.agent.compute_values_by_game(game_name,
                                                                           self.sanity_check_msg)
        output_uncertainties, _ = self.agent.compute_uncertainty_by_game(game_name=game_name,
                                                                         uncertainty_model=self.uncertainty_model,
                                                                         sanity_check_msg=self.sanity_check_msg,
                                                                         transition_game=transition_game)

        pids = [transition.pid for transition in transition_game]
        dones = [transition.done for transition in transition_game]

        def mean_value(at_event):
            # Expectation over the last axis of the event's output.
            return np.mean(output_values[at_event], axis=-1)

        def risk_value_of(at_event, level):
            # Value of one event at the requested risk level.
            if level == 'mean':
                return mean_value(at_event)
            return output_values[at_event, :, int(self.agent.num_tau * level)]

        for alpha in self.alphas:
            baseline = risk_value_of(0, alpha)

            for event_idx in range(1, len(transition_game)):
                current = risk_value_of(event_idx, alpha)
                uncertainty = output_uncertainties[event_idx]

                if baseline is not None:
                    side = judge_home_away(home_away_game_ids=self.home_away_game_ids,
                                           teamId=tids[event_idx],
                                           gameId=game_name)
                    pid = pids[event_idx]
                    ha_id = 0 if side == 'home' else 1
                    impact = (self.agent.gamma * current - baseline)[ha_id]
                    self.player_impact_dict_by_alpha[alpha].setdefault(pid, []).append(impact)
                    self.player_uncertainty_dict_by_alpha[alpha].setdefault(pid, []).append(uncertainty)

                if dones[event_idx]:  # episode ends, the next event has no baseline
                    baseline = None
                elif self.use_expectation_base:
                    baseline = mean_value(event_idx)
                else:
                    baseline = current

        return output_uncertainties


def _sum_player_values(player_value_dict):
    """Reduce each player's list of recorded values to a single total (via np.sum)."""
    return {pid: np.sum(values) for pid, values in player_value_dict.items()}


def _save_correlation_string(model_label, iteration, file_name, corrcoef_string):
    """Write `corrcoef_string` plus a trailing newline to
    ./correlation_results/<model_label>/Iter-<iteration>/<file_name>, creating missing directories."""
    save_dir = './correlation_results/' + model_label + '/Iter-' + str(iteration)
    # makedirs (unlike mkdir) also creates './correlation_results' itself and any nested
    # path component contained in model_label, and tolerates an already existing directory.
    os.makedirs(save_dir, exist_ok=True)
    with open(save_dir + '/' + file_name, 'w') as f:
        f.write(corrcoef_string + '\n')


def run_comparison_player_evaluation(agent, rank_metric,
                                     model_save_path=None, iteration=None,
                                     log_file=None, mode='test', sanity_check_msg='', debug_mode=False, debug_msg=None):
    """Correlate a baseline player metric with the player statistics and report the result.

    A PlayerRanking is built for `rank_metric`; the first of the substrings
    'PM', 'GIM', 'EG', 'SI', 'VAEP' found in `rank_metric` selects the branch.
    Every branch prints a header line ('& <measure>' per key of
    `rpr.player_stats_dict_by_metric`) followed by the correlation string to
    `log_file`. All branches except 'PM' also write that text to
    ./correlation_results/<label>/Iter-<iteration>/correl_test_<metric>.txt.
    If none of the substrings matches, nothing is computed.

    :param agent: object exposing `sports` ('ice-hockey' or 'soccer') and `use_expectation_base`.
    :param rank_metric: name of the metric to evaluate (see above).
    :param model_save_path: path of the evaluated model; the part after the last 'saved'
        is appended to the result directory label (GIM / EG only). None adds no suffix.
    :param iteration: training iteration, used in the 'Iter-<iteration>' directory name.
    :param log_file: file object handed to print(); None prints to stdout.
    :param mode, sanity_check_msg, debug_mode: forwarded unchanged to PlayerRanking.
    :param debug_msg: prefix of the result directory label; None is treated as ''.
    :raises ValueError: if the selected branch meets an `agent.sports` other than
        'ice-hockey' or 'soccer'.
    """
    # The declared defaults are None, which previously crashed on string concatenation / .split().
    debug_msg = '' if debug_msg is None else debug_msg
    model_suffix = '' if model_save_path is None else model_save_path.split('saved')[-1]

    rpr = PlayerRanking(agent=agent,
                        rank_metric=rank_metric,
                        mode=mode,
                        sanity_check_msg=sanity_check_msg,
                        debug_mode=debug_mode,
                        use_expectation_base=agent.use_expectation_base,
                        log_file=log_file)
    header = ' '.join('& {0}'.format(measure) for measure in rpr.player_stats_dict_by_metric.keys())

    if 'PM' in rank_metric:
        if agent.sports == 'ice-hockey':
            player_pm_dict = rpr.load_ice_hockey_player_pm()
            corrcoef_values, corrcoef_string = rpr.compute_hockey_correlations(player_pm_dict, True)
        elif agent.sports == 'soccer':
            player_pm_dict = rpr.load_soccer_player_pm()
            corrcoef_values, corrcoef_string = rpr.compute_soccer_correlations(player_pm_dict, True)
        else:
            raise ValueError("Unknown sports: {0}".format(agent.sports))
        corrcoef_string = header + '\n' + corrcoef_string
        print(corrcoef_string, file=log_file, flush=True)
    elif 'GIM' in rank_metric:
        player_gim_sum_dict = _sum_player_values(rpr.player_impact_dict)
        if agent.sports == 'ice-hockey':
            corrcoef_values, corrcoef_string = rpr.compute_hockey_correlations(player_gim_sum_dict, True)
        elif agent.sports == 'soccer':
            corrcoef_values, corrcoef_string = rpr.compute_soccer_correlations(player_gim_sum_dict, True)
        else:
            raise ValueError("Unknown sports: {0}".format(agent.sports))
        corrcoef_string = header + '\n' + corrcoef_string
        print(corrcoef_string, file=log_file, flush=True)
        _save_correlation_string(model_label=debug_msg + 'correlations_gim' + model_suffix,
                                 iteration=iteration,
                                 file_name='correl_test_gim.txt',
                                 corrcoef_string=corrcoef_string)
    elif 'EG' in rank_metric:
        if agent.sports == 'ice-hockey':
            player_eg_sum_dict = _sum_player_values(rpr.player_expected_goal_dict)
            corrcoef_values, corrcoef_string = rpr.compute_hockey_correlations(player_eg_sum_dict, True)
        elif agent.sports == 'soccer':
            player_eg_sum_dict = rpr.load_soccer_player_eg()
            corrcoef_values, corrcoef_string = rpr.compute_soccer_correlations(player_eg_sum_dict, True)
        else:
            raise ValueError("Unknown sports: {0}".format(agent.sports))
        corrcoef_string = header + '\n' + corrcoef_string
        print(corrcoef_string, file=log_file, flush=True)
        _save_correlation_string(model_label=debug_msg + 'correlations_eg' + model_suffix,
                                 iteration=iteration,
                                 file_name='correl_test_eg.txt',
                                 corrcoef_string=corrcoef_string)
    elif 'SI' in rank_metric:
        if agent.sports == 'ice-hockey':
            player_si_sum_dict = _sum_player_values(rpr.player_score_impact_dict)
            corrcoef_values, corrcoef_string = rpr.compute_hockey_correlations(player_si_sum_dict, True)
        elif agent.sports == 'soccer':
            # NOTE(review): the soccer SI branch loads the plus-minus data (load_soccer_player_pm);
            # kept as-is, please confirm this is intended.
            player_si_sum_dict = rpr.load_soccer_player_pm()
            corrcoef_values, corrcoef_string = rpr.compute_soccer_correlations(player_si_sum_dict, True)
        else:
            raise ValueError("Unknown sports: {0}".format(agent.sports))
        corrcoef_string = header + '\n' + corrcoef_string
        print(corrcoef_string, file=log_file, flush=True)
        _save_correlation_string(model_label=debug_msg + 'correlations_si',
                                 iteration=iteration,
                                 file_name='correl_test_si.txt',
                                 corrcoef_string=corrcoef_string)
    elif 'VAEP' in rank_metric:
        corrcoef_values_all = []
        if agent.sports == 'ice-hockey':
            # One pre-computed VAEP value file per epoch; mean / std are taken over these runs.
            for epoch_num in [5, 6, 7, 8, 9]:
                loaf_vaep_dir = '../../VAEP_hockey/train_neural_net/players_value_epoch{0}.json'.format(epoch_num)
                with open(loaf_vaep_dir, 'r') as read_file:
                    player_vaep_sum_dict = json.load(read_file)
                corrcoef_values, _ = rpr.compute_hockey_correlations(player_vaep_sum_dict, True)
                corrcoef_values_all.append(corrcoef_values)
        elif agent.sports == 'soccer':
            loaf_vaep_dir = '../soccer-data/player_profile_with_alg_values.csv'
            with open(loaf_vaep_dir, 'r') as read_file:
                player_vaep_records = read_file.readlines()
            player_vaep_sum_dict = {}
            for player_vaep_record in player_vaep_records[1:]:  # first line is skipped as the header
                if not player_vaep_record.strip():
                    continue  # tolerate blank lines instead of failing in float('')
                fields = player_vaep_record.split(",")
                # first column is used as the player id, last column as the VAEP value
                player_vaep_sum_dict[fields[0]] = float(fields[-1])
            corrcoef_values, _ = rpr.compute_soccer_correlations(player_vaep_sum_dict, True)
            corrcoef_values_all.append(corrcoef_values)
        else:
            raise ValueError("Unknown sports: {0}".format(agent.sports))
        corrcoef_values_all = np.asarray(corrcoef_values_all)
        corrcoef_values_mean = np.mean(corrcoef_values_all, axis=0)
        corrcoef_values_std = np.std(corrcoef_values_all, axis=0)

        corrcoef_string = ''
        corrcoef_string_mean_std = ''
        for idx in range(len(rpr.interested_measures)):
            corrcoef_string += '& {0} '.format(round(corrcoef_values_mean[idx], 3))
            # raw string: '\p' is not a valid escape sequence in a normal literal
            corrcoef_string_mean_std += r'& {0} $\pm$ {1}'.format(round(corrcoef_values_mean[idx], 3),
                                                                  round(corrcoef_values_std[idx], 3))
        corrcoef_string = header + '\n' + corrcoef_string
        print(corrcoef_string, file=log_file, flush=True)
        print(corrcoef_string_mean_std, file=log_file, flush=True)
        # NOTE(review): the VAEP results are stored under the 'correlations_si' label, which looks
        # like a copy-paste from the SI branch; kept unchanged so existing result paths stay valid.
        _save_correlation_string(model_label=debug_msg + 'correlations_si',
                                 iteration=iteration,
                                 file_name='correl_test_vaep.txt',
                                 corrcoef_string=corrcoef_string)


def run_risk_sensitive_player_evaluation(agent, model_save_path, iteration,
                                         uncertainty_model='gda',
                                         gda_fitting_target="QValues",
                                         mode='test',
                                         log_file=None,
                                         sanity_check_msg='',
                                         debug_mode=False,
                                         debug_msg='',
                                         compute_correlations=True):
    """Run the uncertainty-aware ('UGIM') player evaluation and correlate it with player statistics.

    Steps, in order:
      1. build the result directory ./correlation_results/<model_label>/Iter-<iteration>;
      2. for the 'gda' uncertainty model: fit, save and reload the GDA models of `agent`;
      3. dump values and uncertainties of one example game to a text file;
      4. build a PlayerRanking with alphas 0.03 ... 0.96 (step 0.01) plus 'mean';
      5. derive the uncertainty thresholds and store them in `agent.uncertainty_thresholds`;
      6. unless `compute_correlations` is False, compute the correlations per threshold,
         write one text file per threshold and one plot per measure.

    :param agent: evaluated agent; must expose `sports`, `gda_apply_pd`, `gda_apply_history`,
        `gda_discret_mode`, `gda_fitting_target`, `apply_rnn`, `use_expectation_base` and the
        GDA / value / uncertainty methods called below.
    :param model_save_path: path of the evaluated model, used to derive directory names.
    :param iteration: training iteration, used in the 'Iter-<iteration>' directory name.
    :param uncertainty_model: 'gda' or 'maf'.
    :param gda_fitting_target: target handed to `agent.fit_gda` and used in the GDA save path.
    :param mode, sanity_check_msg, debug_mode, log_file: forwarded to PlayerRanking / the agent.
    :param debug_msg: prefix for the directory labels.
    :param compute_correlations: if False, stop after step 5.
    :return: the PlayerRanking instance if `compute_correlations` is False; otherwise the tuple
        (return_correlations, model_label), where return_correlations maps str(threshold) to the
        last entry of the per-alpha correlations (the entry of the 'mean' alpha).
    :raises ValueError: for an unknown `uncertainty_model` or `agent.sports`.
    """
    alphas = [round(0.01 * i, 2) for i in range(3, 97)] + ['mean']

    uncertainty_msg = ''
    if agent.gda_apply_pd and uncertainty_model == 'gda':
        uncertainty_msg += '_pd'
    if agent.gda_apply_history:
        uncertainty_msg += '_history'

    if uncertainty_model == 'gda':
        # NOTE(review): the label reads agent.gda_fitting_target while the GDA save path below
        # uses the `gda_fitting_target` argument; kept as-is, confirm they are meant to agree.
        model_label = debug_msg + 'correlations_{0}_gda_tgt-{1}'.format(agent.sports, agent.gda_fitting_target) + \
                      '_discret-' + agent.gda_discret_mode + uncertainty_msg + model_save_path.split('saved')[-1]
    elif uncertainty_model == 'maf':
        model_label = debug_msg + 'correlations_{0}_maf'.format(agent.sports) + uncertainty_msg + \
                      model_save_path.split('maf')[-1]
    else:
        raise ValueError("Unknown uncertainty model {0}.".format(uncertainty_model))

    results_dir = './correlation_results/' + model_label + '/Iter-' + str(iteration)
    # makedirs (unlike mkdir) also creates missing parent directories and tolerates existing ones.
    os.makedirs(results_dir, exist_ok=True)

    # Only 'gda' needs fitting here; 'maf' requires no preparation in this function and any
    # other value has already been rejected above.
    if uncertainty_model == 'gda':
        gda_model_save_mother_dir = '../save_model/gda/{0}gda_{1}_tgt-{2}_discret-{3}{4}{5}'.format(
            debug_msg,
            agent.sports,
            gda_fitting_target,
            agent.gda_discret_mode,
            uncertainty_msg,
            model_save_path.split('model')[-1]
        )
        os.makedirs(gda_model_save_mother_dir, exist_ok=True)
        gda_model_save_dir = '{0}/save_model_iter-{1}'.format(gda_model_save_mother_dir, iteration)
        # train the gda
        agent.fit_gda(
            gda_fitting_target=gda_fitting_target,
            debug_mode=debug_mode,
            sanity_check_msg=sanity_check_msg,
            log_file=log_file,
        )
        agent.save_gda(gda_model_path=gda_model_save_dir)
        agent.load_gda(gda_model_path=gda_model_save_dir)

    # Print the mean and uncertainty values of a game
    if agent.sports == 'ice-hockey':
        game_name = '16760'
    elif agent.sports == 'soccer':
        game_name = '917826'
    else:
        raise ValueError("Unknown sports {0}".format(agent.sports))
    output_game, transition_game = agent.compute_values_by_game(game_name, sanity_check_msg)
    uncertainties_game, _ = agent.compute_uncertainty_by_game(game_name,
                                                              sanity_check_msg,
                                                              transition_game=transition_game,
                                                              uncertainty_model=uncertainty_model
                                                              )
    output_mean_all = np.mean(output_game, axis=-1)

    with open(results_dir + '/game-{0}-uncertainty_example.txt'.format(game_name), 'w') as write_file:
        print_game_events_info(transition_game=transition_game,
                               team_values_all=output_mean_all,
                               apply_rnn=agent.apply_rnn,
                               team_uncertainties_all=uncertainties_game,
                               sanity_check_msg=sanity_check_msg,
                               write_file=write_file,
                               sports=agent.sports)

    rpr = PlayerRanking(agent=agent,
                        rank_metric='UGIM',
                        alphas=alphas,
                        mode=mode,
                        sanity_check_msg=sanity_check_msg,
                        uncertainty_model=uncertainty_model,
                        use_expectation_base=agent.use_expectation_base,
                        debug_mode=debug_mode,
                        log_file=log_file)

    # Thresholds come from two sources: the first steps of a 20-step grid between the min and
    # max uncertainty, and the values found at evenly spaced positions of
    # rpr.output_uncertainties_all (presumably sorted, so these act as quantiles - TODO confirm).
    split_num = 7
    uncertainty_range = rpr.max_uncer - rpr.min_uncer
    uncertainty_thresholds = [round(rpr.min_uncer + uncertainty_range / 20 * i, 2)
                              for i in range(1, split_num)]
    uncertainty_threshold_indices = [int(len(rpr.output_uncertainties_all) / split_num * i)
                                     for i in range(1, split_num)]
    uncertainty_thresholds += [round(rpr.output_uncertainties_all[idx], 2)
                               for idx in uncertainty_threshold_indices]
    uncertainty_thresholds.sort()
    uncertainty_thresholds.append(float('inf'))  # final threshold: no uncertainty cut-off
    agent.uncertainty_thresholds = uncertainty_thresholds

    if not compute_correlations:
        return rpr

    header = ' '.join('& {0}'.format(measure) for measure in rpr.player_stats_dict_by_metric.keys())

    if agent.sports == 'ice-hockey':
        all_measures = rpr.interested_measures
    elif agent.sports == 'soccer':
        all_measures = list(rpr.player_stats_dict_by_metric.keys())
    else:
        raise ValueError("Unknown sports {0}".format(agent.sports))

    plot_all_features_correl_dict = {measure: {} for measure in all_measures}
    plot_all_features_x_dict = {measure: {} for measure in all_measures}
    return_correlations = {}

    # The last alpha is the string 'mean', so only the numeric alphas appear on the plot axis.
    num_plot_alphas = len(rpr.alphas) - 1
    confidences = [1 - rpr.alphas[alpha_idx] for alpha_idx in range(num_plot_alphas)]

    for uncertainty_threshold in uncertainty_thresholds:
        corrcoef_all, corrcoef_strings = \
            rpr.compute_correlations_per_alpha(uncertainty_threshold=uncertainty_threshold,
                                               reverse=False,
                                               verbose=False)
        return_correlations["{0}".format(uncertainty_threshold)] = corrcoef_all[-1]
        with open(results_dir + '/correl_test_ssm_threshold-{0}.txt'.format(uncertainty_threshold), 'w') as f:
            f.write(header + '\n')
            for alpha_idx, alpha in enumerate(rpr.alphas):
                f.write(str(alpha) + ' : ' + corrcoef_strings[alpha_idx] + '\n')

        plot_key = 'AUGIM-{0}'.format(uncertainty_threshold)
        for measure_idx, measure in enumerate(all_measures):
            plot_all_features_correl_dict[measure][plot_key] = [corrcoef_all[alpha_idx][measure_idx]
                                                                for alpha_idx in range(num_plot_alphas)]
            # each curve gets its own copy of the x values
            plot_all_features_x_dict[measure][plot_key] = list(confidences)
        # A reverse-threshold variant (reverse=True) used to be evaluated here as well; it was
        # disabled in the original code and is not computed.

    for measure in all_measures:
        plot_correl_dict = plot_all_features_correl_dict[measure]
        plot_curve(draw_keys=plot_correl_dict.keys(),
                   x_dict=plot_all_features_x_dict[measure],
                   y_dict=plot_correl_dict,
                   plot_name=results_dir + '/plot_correl_feature-{0}'.format(measure),
                   legend_size=8,
                   apply_rainbow=True
                   )

    return return_correlations, model_label
